import torch
import torch.nn as nn

"""
GRU（Gated Recurrent Unit）也称门控循环单元结构, 它也是传统RNN的变体, 
同LSTM一样能够有效捕捉长序列之间的语义关联, 缓解梯度消失或爆炸现象. 
同时它的结构和计算要比LSTM更简单, 它的核心结构可以分为两个部分去解析:
    更新门
    重置门
"""
"""
和之前分析过的LSTM中的门控一样, 首先计算更新门和重置门的门值, 分别是z(t)和r(t), 计算方法就是使用X(t)与h(t-1)拼接进行线性变换,
再经过sigmoid激活. 之后重置门门值作用在了h(t-1)上, 代表控制上一时间步传来的信息有多少可以被利用. 
接着就是使用这个重置后的h(t-1)进行基本的RNN计算, 即与x(t)拼接进行线性变化, 经过tanh激活, 得到新的h(t). 
最后更新门的门值会作用在新的h(t)，而1-门值会作用在h(t-1)上, 随后将两者的结果相加, 得到最终的隐含状态输出h(t), 
这个过程意味着更新门有能力保留之前的结果, 当门值趋于1时, 输出就是新的h(t), 而当门值趋于0时, 输出就是上一时间步的h(t-1). 
"""
"""
nn.GRU类初始化主要参数解释:
    input_size: 输入张量x中特征维度的大小.
    hidden_size: 隐层张量h中特征维度的大小.
    num_layers: 隐含层的数量.
    bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用.
"""

rnn = nn.GRU(5,6,2)
input001 = torch.randn(1,3,5) # (sequence_length, batch_size, input_size)
h0 = torch.randn(2, 3, 6)     # (num_layers * num_directions, batch_size, hidden_size)

output001, h1 = rnn(input001, h0)
print(output001)
print('-----------------------------')
print(h1)









